{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "67b741df-7c4b-4fdc-949b-ca97492fa288",
   "metadata": {},
   "source": [
    "## Analyzing Sparse Autoencoders (SAEs) from [Gemma Scope](https://colab.research.google.com/drive/17dQFYUYnuKnP6OwQPH9v_GSYUW5aj-Rp)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "82c139fd-2f1b-48fb-9109-2737565d2323",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/stanfordnlp/pyvene/blob/main/tutorials/basic_tutorials/Sparse_Autoencoder.ipynb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e5d14f0a-b02d-4d1f-863a-dbb1e475e264",
   "metadata": {},
   "outputs": [],
   "source": [
    "__author__ = \"Zhengxuan Wu\"\n",
    "__version__ = \"09/23/2024\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1bbb6ff3-b76b-4672-a8bc-1a9324c5e3de",
   "metadata": {},
   "source": [
    "### Overview\n",
    "\n",
    "This tutorial aims to **(1) reproduce** and **(2) extend** some of the results in the Gemma Scope (SAE) tutorial in [notebook](https://colab.research.google.com/drive/17dQFYUYnuKnP6OwQPH9v_GSYUW5aj-Rp) for interpreting latents of SAEs. This tutorial also shows basic model steering with SAEs. This notebook is built as a show-case for the Gemma 2 2B model as well as its SAEs. However, this tutorial can be extended to any other model types and their SAEs. \n",
    "\n",
    "\n",
    "**Note**: This tutorial assumes SAEs are pretrained separately."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "94a4da6f-d078-4b64-9fd0-d79ece35d3f1",
   "metadata": {},
   "source": [
    "### Set-up"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "dd197c1f-71b5-4379-a9dd-2f6ff27083f6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/nlp/anaconda/main/anaconda3/envs/wuzhengx-310/lib/python3.10/site-packages/transformers/utils/hub.py:127: FutureWarning: Using `TRANSFORMERS_CACHE` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "try:\n",
    "    # This library is our indicator that the required installs\n",
    "    # need to be done.\n",
    "    import pyvene\n",
    "\n",
    "except ModuleNotFoundError:\n",
    "    !pip install git+https://github.com/stanfordnlp/pyvene.git"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "209bfc46-7685-4e66-975f-3280ed516b52",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyvene import (\n",
    "    ConstantSourceIntervention,\n",
    "    SourcelessIntervention,\n",
    "    TrainableIntervention,\n",
    "    DistributedRepresentationIntervention,\n",
    "    CollectIntervention,\n",
    "    JumpReLUAutoencoderIntervention\n",
    ")\n",
    "\n",
    "from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer\n",
    "from huggingface_hub import hf_hub_download, notebook_login\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "# If you haven't login, you need to do so.\n",
    "# notebook_login()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "89e7073e-4a64-49c9-8be2-df7926e08332",
   "metadata": {},
   "source": [
    "### Loading the model and its tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a6e7e7fb-5e73-4711-b378-bc1b04ab1e7f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "192a06afdbdc4c868bc6d20677b3dd38",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3e088f2b3808489bac70b9aeb5ae73f0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "torch.set_grad_enabled(False) # avoid blowing up mem\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    \"google/gemma-2-2b\", # google/gemma-2b-it\n",
    "    device_map='auto',\n",
    ")\n",
    "tokenizer =  AutoTokenizer.from_pretrained(\"google/gemma-2-2b\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0f6af820-a819-4f8e-ac13-6063b8e47e5d",
   "metadata": {},
   "source": [
    "We give it the prompt \"Would you be able to travel through time using a wormhole?\" and print the generated output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "95653dcc-876e-4419-a014-d0308ce12cef",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[     2,  18925,    692,    614,   3326,    577,   5056,   1593,   1069,\n",
      "           2177,    476,  47420,  18216, 235336]], device='cuda:0')\n",
      "<bos>Would you be able to travel through time using a wormhole?\n",
      "\n",
      "[Answer 1]\n",
      "\n",
      "Yes, you can travel through time using a wormhole.\n",
      "\n",
      "A wormhole is a theoretical object that connects two points in space-time. It is a tunnel through space-time that allows objects to travel from\n"
     ]
    }
   ],
   "source": [
    "# The input text\n",
    "prompt = \"Would you be able to travel through time using a wormhole?\"\n",
    "\n",
    "# Use the tokenizer to convert it to tokens. Note that this implicitly adds a special \"Beginning of Sequence\" or <bos> token to the start\n",
    "inputs = tokenizer.encode(prompt, return_tensors=\"pt\", add_special_tokens=True).to(\"cuda\")\n",
    "print(inputs)\n",
    "\n",
    "# Pass it in to the model and generate text\n",
    "outputs = model.generate(input_ids=inputs, max_new_tokens=50)\n",
    "print(tokenizer.decode(outputs[0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71dd281a-562a-4382-8631-43ece42498ce",
   "metadata": {},
   "source": [
    "### Loading a SAE, and create SAE interventions\n",
    "\n",
    "`pyvene` can load SAEs as interventions for analyzing latents as well as model steering."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "d490a50c-a1cd-4def-90c2-cd6bfe67266f",
   "metadata": {},
   "outputs": [],
   "source": [
    "LAYER = 20\n",
    "path_to_params = hf_hub_download(\n",
    "    repo_id=\"google/gemma-scope-2b-pt-res\",\n",
    "    filename=f\"layer_{LAYER}/width_16k/average_l0_71/params.npz\",\n",
    "    force_download=False,\n",
    ")\n",
    "params = np.load(path_to_params)\n",
    "pt_params = {k: torch.from_numpy(v).cuda() for k, v in params.items()}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4e61be8a-257f-4b3d-8112-cb657e9cce40",
   "metadata": {},
   "source": [
    "### Implementing SAEs as `pyvene`-native Interventions\n",
    "\n",
    "Create a `pyvene`-native intervention for SAEs to collect latent collection"
   ]
  },
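  {
   "cell_type": "markdown",
   "id": "jumprelu-encoder-equation",
   "metadata": {},
   "source": [
    "As a reference for the code that follows, the JumpReLU encoder can be sketched as the equation below. The symbols are our own shorthand rather than notation taken from `pyvene` or Gemma Scope: $\\mathbf{x}$ is the activation at one position, $\\mathbf{z}$ is the pre-activation, $\\boldsymbol{\\theta}$ is the per-latent `threshold`, and $\\mathbf{f}$ is the vector of latent activations.\n",
    "\n",
    "$$\n",
    "\\mathbf{z} = \\mathbf{x} W_{\\text{enc}} + \\mathbf{b}_{\\text{enc}}, \\qquad \\mathbf{f}(\\mathbf{x}) = \\mathbf{1}[\\mathbf{z} > \\boldsymbol{\\theta}] \\odot \\mathrm{ReLU}(\\mathbf{z})\n",
    "$$\n",
    "\n",
    "A latent is therefore non-zero only when its pre-activation is positive and exceeds its own threshold."
   ]
  },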
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d785071c-e3ed-4940-92c7-aa8b04df71d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "class JumpReLUSAECollectIntervention(\n",
    "    CollectIntervention\n",
    "):\n",
    "  \"\"\"Collect activations\"\"\"\n",
    "  def __init__(self, **kwargs):\n",
    "    # Note that we initialise these to zeros because we're loading in pre-trained weights.\n",
    "    # If you want to train your own SAEs then we recommend using blah\n",
    "    super().__init__(**kwargs, keep_last_dim=True)\n",
    "    self.W_enc = nn.Parameter(torch.zeros(self.embed_dim, kwargs[\"low_rank_dimension\"]))\n",
    "    self.W_dec = nn.Parameter(torch.zeros(kwargs[\"low_rank_dimension\"], self.embed_dim))\n",
    "    self.threshold = nn.Parameter(torch.zeros(kwargs[\"low_rank_dimension\"]))\n",
    "    self.b_enc = nn.Parameter(torch.zeros(kwargs[\"low_rank_dimension\"]))\n",
    "    self.b_dec = nn.Parameter(torch.zeros(self.embed_dim))\n",
    "\n",
    "  def encode(self, input_acts):\n",
    "    pre_acts = input_acts @ self.W_enc + self.b_enc\n",
    "    mask = (pre_acts > self.threshold)\n",
    "    acts = mask * torch.nn.functional.relu(pre_acts)\n",
    "    return acts\n",
    "\n",
    "  def forward(self, base, source=None, subspaces=None):\n",
    "    acts = self.encode(base)\n",
    "\n",
    "    return acts"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "057aa550-ad6a-4c81-ab37-e732eb202270",
   "metadata": {},
   "source": [
    "### Running the model with SAE to collect activations with `pyvene` APIs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "id": "0accd2d7-b0db-4c21-a27f-b6f9a45f8cb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "sae = JumpReLUSAECollectIntervention(\n",
    "    embed_dim=params['W_enc'].shape[0],\n",
    "    low_rank_dimension=params['W_enc'].shape[1]\n",
    ")\n",
    "sae.load_state_dict(pt_params, strict=False)\n",
    "sae.cuda()\n",
    "\n",
    "# add the intervention to the model computation graph via the config\n",
    "pv_model = pyvene.IntervenableModel({\n",
    "   \"component\": f\"model.layers[{LAYER}].output\",\n",
    "   \"intervention\": sae}, model=model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "14eb3d4e-98ab-47ee-a842-0736407e8c0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "sae_acts = pv_model.forward(\n",
    "    {\"input_ids\": inputs}, return_dict=True).collected_activations[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "id": "b769de21-57d8-4e6b-842a-636be784bb5d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([7017,   47,   65,   70,   55,   72,   65,   75,   80,   72,   68,   93,\n",
       "          86,   89], device='cuda:0')"
      ]
     },
     "execution_count": 73,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\"\"\"\n",
    "Results (from Gemma Scope) should be:\n",
    "tensor([[7017,   47,   65,   70,   55,   72,   65,   75,   80,   72,   68,   93,\n",
    "           86,   89]], device='cuda:0')\n",
    "\"\"\"\n",
    "(sae_acts > 1).sum(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "id": "30d5196f-c5bc-4a13-9363-3b287b9117b7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 6631,  5482, 10376,  1670, 11023,  7562,  9407,  8399, 12935, 10004,\n",
       "        10004, 10004, 12935,  3442], device='cuda:0')"
      ]
     },
     "execution_count": 75,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\"\"\"\n",
    "Results (from Gemma Scope) should be:\n",
    "tensor([[ 6631,  5482, 10376,  1670, 11023,  7562,  9407,  8399, 12935, 10004,\n",
    "         10004, 10004, 12935,  3442]], device='cuda:0')\n",
    "\"\"\"\n",
    "values, inds = sae_acts.max(-1)\n",
    "\n",
    "inds"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "31fdca77-5448-45ef-9816-194c610d919d",
   "metadata": {},
   "source": [
    "### Gemma-2-2B-it steering with Gemma-2-2B SAEs\n",
    "\n",
    "We could also try to steer Gemma-2-2B-it by overloading Gemma-2-2B SAE, and see if it works."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "id": "fc55f530-1e55-4dc4-8cbc-6d26c9f9a0ab",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0574b56084114b98bcd5d0817960a78c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c135c09c7d6d43e5a874b0f7caa92b5b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "torch.set_grad_enabled(False) # avoid blowing up mem\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    \"google/gemma-2-2b-it\", # google/gemma-2b-it\n",
    "    device_map='auto',\n",
    ")\n",
    "tokenizer =  AutoTokenizer.from_pretrained(\"google/gemma-2-2b-it\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "20af0809-1a1d-4279-894e-52070e3de337",
   "metadata": {},
   "source": [
    "### Implementing SAEs as `pyvene`-native Interventions for model steering\n",
    "\n",
    "The `subspace` notation built in to `pyvene` let us to steer models by intervening on different features."
   ]
  },
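  {
   "cell_type": "markdown",
   "id": "sae-steering-equation",
   "metadata": {},
   "source": [
    "In equation form, the steering intervention defined next amounts to the update below. The symbols are our own shorthand: $\\mathbf{h}$ is the layer output, $i$ is `subspaces[\"idx\"]`, $m$ is `subspaces[\"mag\"]`, and $W_{\\text{dec}}[i]$ is row $i$ of the decoder matrix.\n",
    "\n",
    "$$\n",
    "\\mathbf{h}' = \\mathbf{h} + m \\, W_{\\text{dec}}[i]\n",
    "$$\n",
    "\n",
    "Note that the encoder and the thresholds play no role in this update: only the decoder direction of the chosen latent is added, scaled by $m$."
   ]
  },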
  {
   "cell_type": "code",
   "execution_count": 115,
   "id": "043ffb42-f1d2-4b58-99a1-525dfbed45d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "class JumpReLUSAESteeringIntervention(\n",
    "    SourcelessIntervention,\n",
    "    TrainableIntervention, \n",
    "    DistributedRepresentationIntervention\n",
    "):\n",
    "  def __init__(self, **kwargs):\n",
    "    # Note that we initialise these to zeros because we're loading in pre-trained weights.\n",
    "    # If you want to train your own SAEs then we recommend using blah\n",
    "    super().__init__(**kwargs, keep_last_dim=True)\n",
    "    self.W_enc = nn.Parameter(torch.zeros(self.embed_dim, kwargs[\"low_rank_dimension\"]))\n",
    "    self.W_dec = nn.Parameter(torch.zeros(kwargs[\"low_rank_dimension\"], self.embed_dim))\n",
    "    self.threshold = nn.Parameter(torch.zeros(kwargs[\"low_rank_dimension\"]))\n",
    "    self.b_enc = nn.Parameter(torch.zeros(kwargs[\"low_rank_dimension\"]))\n",
    "    self.b_dec = nn.Parameter(torch.zeros(self.embed_dim))\n",
    "\n",
    "  def encode(self, input_acts):\n",
    "    pre_acts = input_acts @ self.W_enc + self.b_enc\n",
    "    mask = (pre_acts > self.threshold)\n",
    "    acts = mask * torch.nn.functional.relu(pre_acts)\n",
    "    return acts\n",
    "\n",
    "  def decode(self, acts):\n",
    "    return acts @ self.W_dec + self.b_dec\n",
    "\n",
    "  def forward(self, base, source=None, subspaces=None):\n",
    "    steering_vec = torch.tensor(subspaces[\"mag\"]) * self.W_dec[subspaces[\"idx\"]]\n",
    "    return base + steering_vec"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c8597e2-69fd-4dc1-b120-af4974a75e79",
   "metadata": {},
   "source": [
    "Loading the Gemma base model SAE weights."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 116,
   "id": "f99d24bc-22a8-4822-8240-4647b27b6d41",
   "metadata": {},
   "outputs": [],
   "source": [
    "sae = JumpReLUSAESteeringIntervention(\n",
    "    embed_dim=params['W_enc'].shape[0],\n",
    "    low_rank_dimension=params['W_enc'].shape[1]\n",
    ")\n",
    "sae.load_state_dict(pt_params, strict=False)\n",
    "sae.cuda()\n",
    "\n",
    "# add the intervention to the model computation graph via the config\n",
    "pv_model = pyvene.IntervenableModel({\n",
    "   \"component\": f\"model.layers[{LAYER}].output\",\n",
    "   \"intervention\": sae}, model=model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 119,
   "id": "971f8a64-2bbb-47bb-9958-50484b5d9e3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = \"Which dog breed do people think is cuter, poodle or doodle?\"\n",
    "\n",
    "prompt = tokenizer(prompt, return_tensors=\"pt\").to(\"cuda\")\n",
    "_, reft_response = pv_model.generate(\n",
    "    prompt, unit_locations=None, intervene_on_prompt=True, \n",
    "    subspaces=[{\"idx\": 10004, \"mag\": 100.0}],\n",
    "    max_new_tokens=128, do_sample=True, early_stopping=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 120,
   "id": "60fe187c-bafe-459c-ad8a-f064d74b2ae3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Which dog breed do people think is cuter, poodle or doodle? \n",
      "\n",
      "It really depends on personal preference, but it's often a subjective matter. \n",
      "\n",
      "Here's a bit about each, to help you decide:\n",
      "\n",
      "**Poodles:**\n",
      "\n",
      "* Origin: France\n",
      "* Types: Standard, Miniature, Toy\n",
      "* Known for: Curly, hypoallergenic fur; intelligence and trainability.\n",
      "* Appearance: Classic, distinguished look with a flowing coat and well-defined facial features.\n",
      "\n",
      "**Doodles:**\n",
      "\n",
      "* Origin (general) Space-travel, time-travel or a blend - depending on the specific dog's ancestry.  The term is used across a variety of breeds.\n"
     ]
    }
   ],
   "source": [
    "print(tokenizer.decode(reft_response[0], skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b863a941-f297-4ebc-9008-62da60fa89d7",
   "metadata": {},
   "source": [
    "**Here you go: a \"Space-travel, time-travel\" Doodle!**"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22cabe19-2c2f-46c7-a631-d0b40fca5308",
   "metadata": {},
   "source": [
    "### Interchange intervention with JumpReLU SAEs.\n",
    "\n",
    "You can also swap values between examples for a specific latent dimension. However, since SAE usually maps a concpet to 1D subspace, swapping between examples and resetting the scalar to another value are similar.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "4f23b199-ca01-4676-9a2d-61b24b96dc2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "sae = JumpReLUAutoencoderIntervention(\n",
    "    embed_dim=params['W_enc'].shape[0],\n",
    "    low_rank_dimension=params['W_enc'].shape[1]\n",
    ")\n",
    "sae.load_state_dict(pt_params, strict=False)\n",
    "sae.cuda()\n",
    "\n",
    "# add the intervention to the model computation graph via the config\n",
    "pv_model = pyvene.IntervenableModel({\n",
    "   \"component\": f\"model.layers[{LAYER}].output\",\n",
    "   \"intervention\": sae}, model=model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "9dbe3883-3588-45fe-91bf-aeb075dea642",
   "metadata": {},
   "outputs": [],
   "source": [
    "base = tokenizer(\n",
    "    \"Which dog breed do people think is cuter, poodle or doodle?\", \n",
    "    return_tensors=\"pt\").to(\"cuda\")\n",
    "source = tokenizer(\n",
    "    \"Origin (general) Space-travel, time-travel\", \n",
    "    return_tensors=\"pt\").to(\"cuda\")\n",
    "\n",
    "# run an interchange intervention \n",
    "original_outputs, intervened_outputs = pv_model(\n",
    "  # the base input\n",
    "  base=base, \n",
    "  # the source input\n",
    "  sources=source, \n",
    "  # the location to intervene (swap last tokens)\n",
    "  unit_locations={\"sources->base\": (11, 14)},\n",
    "  # the SAE latent dimension mapping to the time travel concept (\"10004\")\n",
    "  subspaces=[10004],\n",
    "  output_original_output=True\n",
    ")\n",
    "logits_diff = intervened_outputs.logits[:,-1] - original_outputs.logits[:,-1]\n",
    "values, indices = logits_diff.topk(k=10, sorted=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "57b8c19a-c73f-47e5-b7f3-6b9353802a96",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "** topk logits diff **\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "['PhysRevD',\n",
       " ' transporting',\n",
       " ' teleport',\n",
       " ' space',\n",
       " ' transit',\n",
       " ' transported',\n",
       " ' transporter',\n",
       " ' transpor',\n",
       " ' multiverse',\n",
       " ' universes']"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(\"** topk logits diff **\")\n",
    "tokenizer.batch_decode(indices[0].unsqueeze(dim=-1))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
